package danran.dbapi.filter;

import com.alibaba.fastjson.JSON;
import danran.dbapi.common.ResponseDto;
import danran.dbapi.service.IPService;
import danran.dbapi.utils.IPUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Map;

/**
 * @Classname ApiFilter
 * @Description Servlet filter that adds CORS response headers, answers CORS pre-flight
 *              (OPTIONS) requests directly, and rejects requests whose client IP does not
 *              pass the black/white-list check configured through {@link IPService}.
 * @Date 2022/1/13 18:04
 * @Created by RanCoder
 */
@Component
public class ApiFilter implements Filter {
    private static final Logger logger = LoggerFactory.getLogger(ApiFilter.class);

    @Autowired
    private IPService ipService;

    @Override
    public void init(FilterConfig filterConfig) throws ServletException {
        // No initialization required.
    }

    /**
     * Applies CORS headers, short-circuits pre-flight requests and enforces the IP access list.
     *
     * <p>Exceptions raised by the IP-list lookup or by downstream filters/servlets are not
     * caught here. They propagate to the container so the failure is reported through the
     * normal error handling instead of being swallowed and leaving an empty 200 response.
     *
     * @param request  the incoming request (expected to be an {@link HttpServletRequest})
     * @param response the outgoing response (expected to be an {@link HttpServletResponse})
     * @param chain    the remaining filter chain, invoked only for allowed, non-OPTIONS requests
     * @throws IOException      if writing the rejection body fails or the chain throws it
     * @throws ServletException if the downstream chain throws it
     */
    @Override
    public void doFilter(ServletRequest request,
                         ServletResponse response,
                         FilterChain chain) throws IOException, ServletException {
        logger.debug("Execute filter");
        HttpServletRequest req = (HttpServletRequest) request;
        HttpServletResponse rsp = (HttpServletResponse) response;

        String originIP = IPUtil.getIpAddress();

        rsp.setCharacterEncoding("UTF-8");
        rsp.setContentType("application/json; charset=utf-8");
        applyCorsHeaders(rsp);

        // CORS pre-flight: the headers above are the whole answer; do not forward or IP-check it.
        if ("OPTIONS".equals(req.getMethod())) {
            rsp.setStatus(HttpServletResponse.SC_OK);
            return;
        }

        if (!checkIP(originIP)) {
            writeRejection(rsp, originIP);
            return;
        }
        chain.doFilter(request, response);
    }

    /**
     * Sets the cross-origin (CORS) response headers.
     *
     * <p>NOTE(review): per the Fetch/CORS spec, browsers reject credentialed requests when
     * {@code Access-Control-Allow-Origin} is {@code "*"}. The {@code Allow-Credentials: true}
     * header therefore has no effect for such requests. If credentials are required, echo a
     * validated request Origin instead of the wildcard.
     */
    private void applyCorsHeaders(HttpServletResponse rsp) {
        rsp.setHeader("Access-Control-Allow-Origin", "*");
        rsp.setHeader("Access-Control-Allow-Credentials", "true");
        // Important: without this, browser JS cannot send the Authorization header cross-origin.
        rsp.setHeader("Access-Control-Allow-Headers", "Authorization");
        rsp.setHeader("Access-Control-Allow-Methods", "POST, GET, PUT, OPTIONS, DELETE");
    }

    /**
     * Writes the JSON "forbidden IP" failure body. The HTTP status is not changed here,
     * matching the original behavior of reporting the rejection in the response body.
     */
    private void writeRejection(HttpServletResponse rsp, String originIP) throws IOException {
        PrintWriter out = rsp.getWriter();
        out.append(JSON.toJSONString(ResponseDto.fail("非法的ip (" + originIP + "), 禁止访问")));
        out.flush();
    }

    /**
     * Decides whether {@code originIP} may access the API.
     *
     * <p>Filtering applies only when the configuration's {@code "status"} is {@code "on"}.
     * In {@code "black"} mode the {@code "blackIP"} entry is checked, and in {@code "white"}
     * mode the {@code "whiteIP"} entry is checked. The actual matching is delegated to
     * {@link IPService#check}. Any other status or mode allows the request.
     *
     * @param originIP the client IP as resolved by {@code IPUtil}
     * @return {@code false} if the configured list rejects the IP, otherwise {@code true}
     */
    private boolean checkIP(String originIP) {
        Map<String, String> map = ipService.detail();
        String status = map.get("status");
        if (!"on".equals(status)) {
            return true;
        }
        String mode = map.get("mode");
        if ("black".equals(mode)) {
            if (!ipService.check(mode, map.get("blackIP"), originIP)) {
                logger.warn("ip黑名单拦截");
                return false;
            }
        } else if ("white".equals(mode)) {
            if (!ipService.check(mode, map.get("whiteIP"), originIP)) {
                logger.warn("ip白名单检查不通过");
                return false;
            }
        }
        return true;
    }

    @Override
    public void destroy() {
        // No resources to release.
    }
}
